# -*- coding: utf-8 -*-
"""
@Time    : 2024/7/10 20:28 
@Author  : ZhangShenao 
@File    : qa_retriever.py 
@Desc    : 问答检索器
"""

from langchain.chains.retrieval_qa.base import RetrievalQA, BaseRetrievalQA
from langchain.retrievers import MultiQueryRetriever
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.vectorstores import VST, VectorStore


class QARetriever:
    """Question-answering retriever.

    Stateless factory: call :meth:`new_retriever_chain` to build a
    RetrievalQA chain over a vector store.
    """

    @staticmethod
    def new_retriever_chain(
        vector_store: VectorStore,
        *,
        model_name: str = "glm-4-air",
        chain_type: str = "stuff",
    ) -> BaseRetrievalQA:
        """Build a RetrievalQA chain that answers questions from ``vector_store``.

        A single ChatZhipuAI model is used twice. First, ``MultiQueryRetriever``
        uses it to generate several phrasings of the incoming question before
        searching ``vector_store``. Second, ``RetrievalQA`` uses it to answer
        from the retrieved documents.

        Args:
            vector_store: Store to search. It is wrapped with
                ``vector_store.as_retriever()`` using that method's default
                search settings.
            model_name: ZhipuAI chat model to use. Defaults to ``"glm-4-air"``,
                the value that was previously hard-coded here.
            chain_type: Document-combining strategy forwarded to
                ``RetrievalQA.from_chain_type``. ``"stuff"`` matches LangChain's
                own default.

        Returns:
            The configured retrieval QA chain.
        """
        # NOTE(review): ChatZhipuAI presumably picks up its API key from the
        # environment (ZHIPUAI_API_KEY) when none is passed — confirm before deploying.
        llm = ChatZhipuAI(model_name=model_name)

        # Multi-query retrieval: the LLM rewrites the question into variants
        # and the documents retrieved for each variant are merged.
        retriever = MultiQueryRetriever.from_llm(
            retriever=vector_store.as_retriever(),
            llm=llm,
        )

        return RetrievalQA.from_chain_type(
            llm,
            chain_type=chain_type,
            retriever=retriever,
        )
